/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/

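/* !
 * \file kernel_micro_datacopy_impl.h
 * \brief MicroAPI data-copy implementations for dav_c310: vector gather (vgather2 / vgatherb),
 *        vector scatter (vscatter), mask-register load/store (pld / plds / pst / psts) and
 *        unaligned store (pstu).
 */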
#ifndef ASCENDC_MODULE_MICRO_DATACOPY_IMPL_H
#define ASCENDC_MODULE_MICRO_DATACOPY_IMPL_H
#include "micro_api/dav_c310/kernel_micro_datacopy_load_impl.h"
#include "micro_api/dav_c310/kernel_micro_datacopy_store_impl.h"

namespace AscendC {
namespace MicroAPI {
/**
 * \brief Gathers 64-bit elements as pairs of 32-bit words using two vgather2 operations.
 *
 * baseAddr is viewed as an array of uint32_t words. For every index lane, the words at
 * 2*index and 2*index+1 are gathered separately and then recombined into dstReg
 * (this assumes vgather2 indices count elements of the uint32_t view — confirm).
 *
 * \param dstReg   destination. RegTraitNumOne: the two word streams are merged with Interleave
 *                 into dstReg. Otherwise: the first word stream goes to dstReg.reg[0] and the
 *                 second to dstReg.reg[1].
 * \param baseAddr source buffer base address.
 * \param index    u32 element indices.
 * \param mask     caller's lane mask.
 */
template <typename DstT = DefaultType, typename SrcT, typename IndexT = DefaultType, typename RegDstT,
    typename RegIndexT>
__simd_callee__ inline void DataCopyGatherB64Impl(
    RegDstT &dstReg, __local_mem__ SrcT *baseAddr, RegIndexT &index, MaskReg &mask)
{
    if constexpr (!CheckRegTrait<RegDstT, RegTraitNumOne>()) {
        // Multi-register destination: each word stream is gathered straight into its own register.
        RegTensor<uint32_t> firstWordIdx;
        RegTensor<uint32_t> secondWordIdx;
        Muls(firstWordIdx, index, uint32_t(2), mask);
        Adds(secondWordIdx, firstWordIdx, uint32_t(1), mask);
        vgather2((RegTensor<uint32_t> &)dstReg.reg[0], (__ubuf__ uint32_t *)baseAddr, firstWordIdx, mask);
        vgather2((RegTensor<uint32_t> &)dstReg.reg[1], (__ubuf__ uint32_t *)baseAddr, secondWordIdx, mask);
    } else {
        // Single-register destination: pack the caller's mask, then AND it with the VL32 pattern
        // (presumably the first 32 u32 lanes — confirm against CreateMask).
        MaskReg packedMask;
        MaskPack(packedMask, mask);
        MaskReg vl32Mask = CreateMask<uint32_t, MaskPattern::VL32>();
        MaskReg allMask = CreateMask<uint32_t, MaskPattern::ALL>();
        MaskAnd(packedMask, packedMask, vl32Mask, allMask);

        // Word offsets of the two halves of each 64-bit element.
        RegTensor<uint32_t> firstWordIdx;
        RegTensor<uint32_t> secondWordIdx;
        Muls(firstWordIdx, index, uint32_t(2), packedMask);
        Adds(secondWordIdx, firstWordIdx, uint32_t(1), packedMask);

        RegTensor<uint32_t> firstWords;
        RegTensor<uint32_t> secondWords;
        vgather2(firstWords, (__ubuf__ uint32_t *)baseAddr, firstWordIdx, packedMask);
        vgather2(secondWords, (__ubuf__ uint32_t *)baseAddr, secondWordIdx, packedMask);

        // Merge the two word streams; the first Interleave output is dstReg, the second output is a
        // scratch register that is not read afterwards.
        RegTensor<uint32_t> unusedUpper;
        Interleave((RegTensor<uint32_t> &)dstReg, unusedUpper, firstWords, secondWords);
    }
}

// vgather2
/**
 * \brief Gathers elements of baseAddr selected by index into dstReg (vgather2 front-end).
 *
 * Accepted combinations (source width / destination width / index type):
 *   b8 -> b16 / u16, b16 -> b16 / u16, b32 -> b32 / u32, b64 -> b64 / (u32 | u64).
 * 64-bit elements are handled by DataCopyGatherB64Impl. For u64 indices only the 32-bit word the
 * code treats as the low part (DeInterleave's first output, or index.reg[0]) is used as the index.
 *
 * \tparam DstT   optional explicit destination element type (DefaultType or RegDstT::ActualT).
 * \tparam IndexT optional explicit index element type (DefaultType or RegIndexT::ActualT).
 */
template <typename DstT = DefaultType, typename SrcT, typename IndexT = DefaultType, typename RegDstT,
    typename RegIndexT>
__simd_callee__ inline void DataCopyGatherImpl(
    RegDstT &dstReg, __local_mem__ SrcT *baseAddr, RegIndexT &index, MaskReg &mask)
{
    using ActualDstT = typename RegDstT::ActualT;
    using ActualIndexT = typename RegIndexT::ActualT;

    constexpr auto srcBytes = sizeof(SrcT);
    constexpr auto dstBytes = sizeof(ActualDstT);
    constexpr bool isIdxU16 = std::is_same_v<ActualIndexT, uint16_t>;
    constexpr bool isIdxU32 = std::is_same_v<ActualIndexT, uint32_t>;
    constexpr bool isIdxU64 = std::is_same_v<ActualIndexT, uint64_t>;
    constexpr bool idxOneReg = CheckRegTrait<RegIndexT, RegTraitNumOne>();
    constexpr bool idxTwoReg = CheckRegTrait<RegIndexT, RegTraitNumTwo>();
    constexpr bool dstOneReg = CheckRegTrait<RegDstT, RegTraitNumOne>();
    constexpr bool dstTwoReg = CheckRegTrait<RegDstT, RegTraitNumTwo>();

    static_assert(std::is_same_v<DstT, DefaultType> || std::is_same_v<DstT, ActualDstT>, "DstT type is not correct!");
    static_assert(
        std::is_same_v<IndexT, DefaultType> || std::is_same_v<IndexT, ActualIndexT>, "IndexT type is not correct!");
    static_assert((srcBytes == 1 && dstBytes == 2 && isIdxU16) ||
                      (srcBytes == 2 && dstBytes == 2 && isIdxU16) ||
                      (srcBytes == 4 && dstBytes == 4 && isIdxU32) ||
                      (srcBytes == 8 && dstBytes == 8 && SupportType<ActualIndexT, uint32_t, uint64_t>()),
        "DataCopyGather only support src data type b8/b16/b32/b64 with dst type is b16/b16/b32/b64 respectively and "
        "each index type is u16/u16/u32/(u32/u64) respectively on current device");
    // A single-register u64 index only carries 32 valid elements, so it cannot feed a two-register
    // b64 destination.
    static_assert(!(srcBytes == 8 && isIdxU64 && idxOneReg && dstTwoReg),
        "current data type is not supported on current device!");

    if constexpr (srcBytes == 1 && dstBytes == 2) {
        // b8 source gathered into b16 lanes.
        vgather2((vector_s16 &)dstReg, (__ubuf__ int8_t *)baseAddr, index, mask);
    } else if constexpr (srcBytes == 2 && dstBytes == 2) {
        vgather2((vector_s16 &)dstReg, (__ubuf__ int16_t *)baseAddr, index, mask);
    } else if constexpr (srcBytes == 4 && dstBytes == 4) {
        vgather2((vector_s32 &)dstReg, (__ubuf__ int32_t *)baseAddr, index, mask);
    } else if constexpr (isIdxU32) {
        // b64 with u32 indices: forward directly.
        DataCopyGatherB64Impl(dstReg, baseAddr, index, mask);
    } else if constexpr (isIdxU64 && idxOneReg && dstOneReg) {
        // b64 with a single-register u64 index: split its u32 view and keep the first stream.
        RegTensor<uint32_t> idxLow32;
        RegTensor<uint32_t> idxHigh32;
        DeInterleave(idxLow32, idxHigh32, (RegTensor<uint32_t> &)index, (RegTensor<uint32_t> &)index);
        DataCopyGatherB64Impl(dstReg, baseAddr, idxLow32, mask);
    } else if constexpr (isIdxU64 && idxTwoReg && (dstOneReg || dstTwoReg)) {
        // b64 with a two-register u64 index: reg[0] supplies the 32-bit indices.
        DataCopyGatherB64Impl(dstReg, baseAddr, (RegTensor<uint32_t> &)index.reg[0], mask);
    }
}

// vgatherb
/**
 * \brief Gathers into dstReg from baseAddr using the vgatherb instruction with u32 indices.
 *
 * Both dstReg and index must be single-register (RegTraitNumOne) tensors. The element width of
 * RegT::ActualT (1/2/4/8 bytes) selects the signed vector view passed to vgatherb.
 * NOTE(review): whether index entries are element or byte offsets is defined by vgatherb — confirm.
 *
 * \tparam T       optional explicit element type; must be DefaultType or RegT::ActualT.
 * \param dstReg   destination register.
 * \param baseAddr source buffer base address.
 * \param index    u32 index register.
 * \param mask     lane mask.
 */
template <typename T = DefaultType, typename RegT, typename RegIndexT>
__simd_callee__ inline void DataCopyGatherBImpl(RegT &dstReg, __local_mem__ T *baseAddr, RegIndexT &index, MaskReg &mask)
{
    using ActualT = typename RegT::ActualT;
    using ActualIndexT = typename RegIndexT::ActualT;
    static_assert(std::is_same_v<T, DefaultType> || std::is_same_v<T, ActualT>, "T type is not correct!");
    static_assert(std::is_same_v<ActualIndexT, uint32_t>, "IndexT type is not correct!");
    static_assert(CheckRegTrait<RegT, RegTraitNumOne>(), "RegTensor only support RegTraitNumOne on current device!");
    static_assert(CheckRegTrait<RegIndexT, RegTraitNumOne>(),
        "RegTensor only support RegTraitNumOne on current device!");
    static_assert(SupportBytes<ActualT, 1, 2, 4, 8>(),
        "DataCopyGatherB only support src & dst datatype b8/b16/b32/b64 on current device");
    if constexpr (sizeof(ActualT) == 1) {
        vgatherb((vector_s8 &)dstReg, (__ubuf__ int8_t *)baseAddr, index, mask);
    } else if constexpr (sizeof(ActualT) == 2) {
        vgatherb((vector_s16 &)dstReg, (__ubuf__ int16_t *)baseAddr, index, mask);
    } else if constexpr (sizeof(ActualT) == 4) {
        vgatherb((vector_s32 &)dstReg, (__ubuf__ int32_t *)baseAddr, index, mask);
    } else {
        // Only the 8-byte case remains after the SupportBytes check above.
        vgatherb((vector_s64 &)dstReg, (__ubuf__ int64_t *)baseAddr, index, mask);
    }
}

/**
 * \brief Scatters 64-bit elements of srcReg to baseAddr as pairs of 32-bit words (two vscatter ops).
 *
 * baseAddr is viewed as an array of uint32_t words. For each index lane, the first word of the
 * element is written at 2*index and the second at 2*index+1 (this assumes vscatter indices count
 * elements of the uint32_t view — confirm).
 *
 * \param baseAddr destination buffer base address.
 * \param srcReg   source data. RegTraitNumOne: its u32 view is split into two word streams with
 *                 DeInterleave. Otherwise: srcReg.reg[0] and srcReg.reg[1] supply the two streams.
 * \param index    u32 element indices.
 * \param mask     caller's lane mask.
 */
template <typename T = DefaultType, typename IndexT = DefaultType, typename RegT, typename RegIndexT>
__simd_callee__ inline void DataCopyScatterB64Impl(__local_mem__ T *baseAddr, RegT &srcReg, RegIndexT &index, MaskReg &mask)
{
    if constexpr (CheckRegTrait<RegT, RegTraitNumOne>()) {
        // Pack the caller's mask, then AND it with the VL32 pattern
        // (presumably the first 32 u32 lanes — confirm against CreateMask).
        MaskReg packedMask;
        MaskPack(packedMask, mask);
        MaskReg vl32Mask = CreateMask<uint32_t, MaskPattern::VL32>();
        MaskReg allMask = CreateMask<uint32_t, MaskPattern::ALL>();
        MaskAnd(packedMask, packedMask, vl32Mask, allMask);

        // Word offsets of the two halves of each 64-bit element.
        RegTensor<uint32_t> firstWordIdx;
        RegTensor<uint32_t> secondWordIdx;
        Muls(firstWordIdx, index, uint32_t(2), packedMask);
        Adds(secondWordIdx, firstWordIdx, uint32_t(1), packedMask);

        // Split the source's u32 view into two word streams; these hold the data being stored.
        RegTensor<uint32_t> firstWords;
        RegTensor<uint32_t> secondWords;
        DeInterleave(firstWords, secondWords, (RegTensor<uint32_t> &)srcReg, (RegTensor<uint32_t> &)srcReg);
        vscatter(firstWords, (__ubuf__ uint32_t *)baseAddr, firstWordIdx, packedMask);
        vscatter(secondWords, (__ubuf__ uint32_t *)baseAddr, secondWordIdx, packedMask);
    } else {
        // Two-register source: each register is scattered as its own word stream.
        RegTensor<uint32_t> firstWordIdx;
        RegTensor<uint32_t> secondWordIdx;
        Muls(firstWordIdx, index, uint32_t(2), mask);
        Adds(secondWordIdx, firstWordIdx, uint32_t(1), mask);
        vscatter((RegTensor<uint32_t> &)srcReg.reg[0], (__ubuf__ uint32_t *)baseAddr, firstWordIdx, mask);
        vscatter((RegTensor<uint32_t> &)srcReg.reg[1], (__ubuf__ uint32_t *)baseAddr, secondWordIdx, mask);
    }
}

// vscatter
/**
 * \brief Scatters srcReg into baseAddr at the positions given by index (vscatter front-end).
 *
 * Accepted combinations (element width / index type): b8/u16, b16/u16, b32/u32, b64/(u32 | u64).
 * 64-bit elements are handled by DataCopyScatterB64Impl. For u64 indices only the 32-bit word the
 * code treats as the low part (DeInterleave's first output, or index.reg[0]) is used as the index.
 *
 * \tparam T      optional explicit element type (DefaultType or RegT::ActualT).
 * \tparam IndexT optional explicit index element type (DefaultType or RegIndexT::ActualT).
 */
template <typename T = DefaultType, typename IndexT = DefaultType, typename RegT, typename RegIndexT>
__simd_callee__ inline void DataCopyScatterImpl(__local_mem__ T *baseAddr, RegT &srcReg, RegIndexT &index, MaskReg &mask)
{
    using ActualT = typename RegT::ActualT;
    using ActualIndexT = typename RegIndexT::ActualT;
    static_assert(std::is_same_v<T, DefaultType> || std::is_same_v<T, ActualT>, "T type is not correct!");
    static_assert(
        std::is_same_v<IndexT, DefaultType> || std::is_same_v<IndexT, ActualIndexT>, "IndexT type is not correct!");
    static_assert((sizeof(ActualT) == 1 && std::is_same_v<ActualIndexT, uint16_t>) ||
                      (sizeof(ActualT) == 2 && std::is_same_v<ActualIndexT, uint16_t>) ||
                      (sizeof(ActualT) == 4 && std::is_same_v<ActualIndexT, uint32_t>) ||
                      (sizeof(ActualT) == 8 && SupportType<ActualIndexT, uint32_t, uint64_t>()),
        "DataCopyScatter only support data type b8/b16/b32/b64 "
        "with each index type is u16/u16/u32/(u32/u64) respectively on current device");
    // A single-register u64 index only carries 32 valid elements, so it cannot address a
    // two-register b64 source.
    static_assert(!(sizeof(ActualT) == 8 && std::is_same_v<ActualIndexT, uint64_t> &&
                      CheckRegTrait<RegT, RegTraitNumTwo>() && CheckRegTrait<RegIndexT, RegTraitNumOne>()),
        "current data type is not supported on current device!");
    if constexpr (sizeof(ActualT) == 8) {
        if constexpr (std::is_same_v<ActualIndexT, uint32_t>) {
            DataCopyScatterB64Impl(baseAddr, srcReg, index, mask);
        } else if constexpr (std::is_same_v<ActualIndexT, uint64_t>) {
            if constexpr (CheckRegTrait<RegT, RegTraitNumOne>() && CheckRegTrait<RegIndexT, RegTraitNumOne>()) {
                // Split the u64 index's u32 view and use the first stream as 32-bit indices.
                RegTensor<uint32_t> lowIndex;
                RegTensor<uint32_t> highIndex;
                DeInterleave(lowIndex, highIndex, (RegTensor<uint32_t> &)index, (RegTensor<uint32_t> &)index);
                DataCopyScatterB64Impl(baseAddr, srcReg, lowIndex, mask);
            } else if constexpr ((CheckRegTrait<RegT, RegTraitNumOne>() &&
                                     CheckRegTrait<RegIndexT, RegTraitNumTwo>()) ||
                                 (CheckRegTrait<RegT, RegTraitNumTwo>() &&
                                     CheckRegTrait<RegIndexT, RegTraitNumTwo>())) {
                // Two-register u64 index: reg[0] supplies the 32-bit indices.
                DataCopyScatterB64Impl(baseAddr, srcReg, (RegTensor<uint32_t> &)index.reg[0], mask);
            }
        }
    } else {
        vscatter(srcReg, baseAddr, index, mask);
    }
}

// pld
/**
 * \brief Loads mask register `mask` from srcUbAddr at register offset `offset` using pld.
 * \tparam T    element type of the buffer (1/2/4/8 bytes); the address is passed as uint32_t*.
 * \tparam dist distribution mode; only DIST_NORM, DIST_US and DIST_DS are accepted.
 */
template <typename T, MaskDist dist = MaskDist::DIST_NORM>
__simd_callee__ inline void DataCopyImpl(MaskReg &mask, __local_mem__ T *srcUbAddr, AddrReg offset)
{
    static_assert(SupportBytes<T, 1, 2, 4, 8>(), "DataCopy only support type b8/b16/b32/b64 on current device");
    static_assert(SupportEnum<dist, MaskDist::DIST_NORM, MaskDist::DIST_US, MaskDist::DIST_DS>(),
        "DataCopy not support this dist on current device");
    // The MaskDist value is static_cast to the intrinsic's ::Dist and passed as a compile-time tag.
    using DistTag = std::integral_constant<::Dist, static_cast<::Dist>(dist)>;
    constexpr DistTag distMode{};
    pld(mask, (__ubuf__ uint32_t *)srcUbAddr, offset, distMode);
}

// plds
/**
 * \brief Loads mask register `mask` from srcUbAddr (immediate offset 0) using plds.
 * \tparam T    element type of the buffer (1/2/4/8 bytes); the address is passed as uint32_t*.
 * \tparam dist distribution mode; only DIST_NORM, DIST_US and DIST_DS are accepted.
 */
template <typename T, MaskDist dist = MaskDist::DIST_NORM>
__simd_callee__ inline void DataCopyImpl(MaskReg &mask, __local_mem__ T *srcUbAddr)
{
    static_assert(SupportBytes<T, 1, 2, 4, 8>(), "DataCopy only support type b8/b16/b32/b64 on current device");
    static_assert(SupportEnum<dist, MaskDist::DIST_NORM, MaskDist::DIST_US, MaskDist::DIST_DS>(),
        "DataCopy not support this dist on current device");
    // The MaskDist value is static_cast to the intrinsic's ::Dist and passed as a compile-time tag.
    using DistTag = std::integral_constant<::Dist, static_cast<::Dist>(dist)>;
    constexpr DistTag distMode{};
    plds(mask, (__ubuf__ uint32_t *)srcUbAddr, 0, distMode);
}

/**
 * \brief Loads mask register `mask` from srcUbAddr with immediate `offset` using plds and the
 *        post mode given by postMode.
 *
 * srcUbAddr is passed to plds by reference (as `__ubuf__ uint32_t *&`), so plds may update the
 * caller's pointer — NOTE(review): exact update semantics per ::Post value are defined by plds.
 *
 * \tparam dist distribution mode; only DIST_NORM, DIST_US and DIST_DS are accepted.
 */
template <typename T, PostLiteral postMode, MaskDist dist = MaskDist::DIST_NORM>
__simd_callee__ inline void DataCopyImpl(MaskReg &mask, __local_mem__ T *&srcUbAddr, int32_t offset)
{
    static_assert(SupportBytes<T, 1, 2, 4, 8>(), "DataCopy only support type b8/b16/b32/b64 on current device");
    static_assert(SupportEnum<dist, MaskDist::DIST_NORM, MaskDist::DIST_US, MaskDist::DIST_DS>(),
        "DataCopy not support this dist on current device");
    // Both enums are static_cast to the intrinsic's own enums and passed as compile-time tags.
    using DistTag = std::integral_constant<::Dist, static_cast<::Dist>(dist)>;
    using PostTag = std::integral_constant<::Post, static_cast<::Post>(postMode)>;
    constexpr DistTag distMode{};
    constexpr PostTag postUpdate{};
    plds(mask, (__ubuf__ uint32_t *&)srcUbAddr, offset, distMode, postUpdate);
}

// pst
/**
 * \brief Stores mask register `mask` to dstUbAddr at register offset `offset` using pst.
 * \tparam T    element type of the buffer (1/2/4/8 bytes); the address is passed as uint32_t*.
 * \tparam dist distribution mode; only DIST_NORM and DIST_PACK are accepted.
 */
template <typename T, MaskDist dist = MaskDist::DIST_NORM>
__simd_callee__ inline void DataCopyImpl(__local_mem__ T *dstUbAddr, MaskReg &mask, AddrReg offset)
{
    static_assert(SupportBytes<T, 1, 2, 4, 8>(), "DataCopy only support type b8/b16/b32/b64 on current device");
    static_assert(SupportEnum<dist, MaskDist::DIST_NORM, MaskDist::DIST_PACK>(),
        "DataCopy not support this dist on current device");
    // The MaskDist value is static_cast to the intrinsic's ::Dist and passed as a compile-time tag.
    using DistTag = std::integral_constant<::Dist, static_cast<::Dist>(dist)>;
    constexpr DistTag distMode{};
    pst(mask, (__ubuf__ uint32_t *)dstUbAddr, offset, distMode);
}

// psts
/**
 * \brief Stores mask register `mask` to dstUbAddr (immediate offset 0) using psts.
 * \tparam T    element type of the buffer (1/2/4/8 bytes); the address is passed as uint32_t*.
 * \tparam dist distribution mode; only DIST_NORM and DIST_PACK are accepted.
 */
template <typename T, MaskDist dist = MaskDist::DIST_NORM>
__simd_callee__ inline void DataCopyImpl(__local_mem__ T *dstUbAddr, MaskReg &mask)
{
    static_assert(SupportBytes<T, 1, 2, 4, 8>(), "DataCopy only support type b8/b16/b32/b64 on current device");
    static_assert(SupportEnum<dist, MaskDist::DIST_NORM, MaskDist::DIST_PACK>(),
        "DataCopy not support this dist on current device");
    // The MaskDist value is static_cast to the intrinsic's ::Dist and passed as a compile-time tag.
    using DistTag = std::integral_constant<::Dist, static_cast<::Dist>(dist)>;
    constexpr DistTag distMode{};
    psts(mask, (__ubuf__ uint32_t *)dstUbAddr, 0, distMode);
}

/**
 * \brief Stores mask register `mask` to dstUbAddr with immediate `offset` using psts and the
 *        post mode given by postMode.
 *
 * dstUbAddr is passed to psts by reference (as `__ubuf__ uint32_t *&`), so psts may update the
 * caller's pointer — NOTE(review): exact update semantics per ::Post value are defined by psts.
 *
 * \tparam dist distribution mode; only DIST_NORM and DIST_PACK are accepted.
 */
template <typename T, PostLiteral postMode, MaskDist dist = MaskDist::DIST_NORM>
__simd_callee__ inline void DataCopyImpl(__local_mem__ T *&dstUbAddr, MaskReg &mask, int32_t offset)
{
    static_assert(SupportBytes<T, 1, 2, 4, 8>(), "DataCopy only support type b8/b16/b32/b64 on current device");
    static_assert(SupportEnum<dist, MaskDist::DIST_NORM, MaskDist::DIST_PACK>(),
        "DataCopy not support this dist on current device");
    // Both enums are static_cast to the intrinsic's own enums and passed as compile-time tags.
    using DistTag = std::integral_constant<::Dist, static_cast<::Dist>(dist)>;
    using PostTag = std::integral_constant<::Post, static_cast<::Post>(postMode)>;
    constexpr DistTag distMode{};
    constexpr PostTag postUpdate{};
    psts(mask, (__ubuf__ uint32_t *&)dstUbAddr, offset, distMode, postUpdate);
}

/**
 * \brief Unaligned store through pstu using the unaligned register `ureg` and lane mask `mask`.
 *
 * Only 2- and 4-byte element types are accepted; the pointer is reinterpreted as an unsigned
 * type of the same width. dstUbAddr is passed to pstu by reference, so pstu may update the
 * caller's pointer — NOTE(review): confirm the exact update semantics against pstu.
 */
template <typename T>
__simd_callee__ inline void DataCopyUnAlignImpl(__local_mem__ T *&dstUbAddr, MaskReg &mask, UnalignReg &ureg)
{
    static_assert(SupportBytes<T, 2, 4>(), "DataCopy only support type b16/b32 on current device");
    constexpr auto elemBytes = sizeof(T);
    if constexpr (elemBytes == 4) {
        pstu(ureg, mask, (__ubuf__ uint32_t *&)dstUbAddr);
    } else if constexpr (elemBytes == 2) {
        pstu(ureg, mask, (__ubuf__ uint16_t *&)dstUbAddr);
    }
}
} // namespace MicroAPI
} // namespace AscendC
#endif // ASCENDC_MODULE_MICRO_DATACOPY_IMPL_H
